﻿using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
using System.Reflection;
using System.Diagnostics;
using System.Collections;

namespace QueryInterception
{


    /// <summary>
    /// Rewrites an expression tree, replacing every occurrence of a mapped source type
    /// (typically an interface) with its replacement type. Replacement is applied to
    /// parameters, lambdas, member accesses, constructor calls (including anonymous types),
    /// method calls, unary conversions and constants, including where the source type only
    /// appears as a generic argument or an array element type.
    /// </summary>
    /// <remarks>http://stackoverflow.com/a/9120931/57883</remarks>
    public class TypeChangeVisitor : ExpressionVisitor
    {
        /// <summary>Source type to replacement type; a private copy of the constructor argument.</summary>
        readonly IDictionary<Type, Type> _typeReplacements;

        /// <summary>
        /// In a given query the params must be the same instance, not just the same name/type
        /// </summary>
        /// <remarks>
        /// The parameter xxx was not bound in the specified LINQ to Entities query expression
        /// http://social.msdn.microsoft.com/Forums/en/adodotnetentityframework/thread/8c2b0b1c-01bb-4de2-af46-0b4ea866cf8f
        /// </remarks>
        readonly Dictionary<ParameterExpression, ParameterExpression> _paramMappings = new Dictionary<ParameterExpression, ParameterExpression>();

        /// <summary>
        /// Creates a visitor for the given type map. For every interface key, each of its base
        /// interfaces that is not itself a key is additionally mapped to the same replacement type.
        /// </summary>
        /// <param name="typeReplacements">Source type to replacement type. This dictionary is not modified.</param>
        /// <exception cref="ArgumentNullException"><paramref name="typeReplacements"/> is null.</exception>
        public TypeChangeVisitor(IDictionary<Type, Type> typeReplacements)
        {
            if (typeReplacements == null)
                throw new ArgumentNullException("typeReplacements");

            // Work on a copy so constructing the visitor does not mutate the caller's dictionary.
            _typeReplacements = new Dictionary<Type, Type>(typeReplacements);

            foreach (var pair in typeReplacements)
            {
                if (!pair.Key.IsInterface)
                    continue;
                foreach (var baseInterface in pair.Key.GetInterfaces())
                {
                    // Explicit keys always win. When several interface keys share a base interface,
                    // the first one enumerated wins instead of throwing on a duplicate key.
                    if (!_typeReplacements.ContainsKey(baseInterface))
                        _typeReplacements.Add(baseInterface, pair.Value);
                }
            }
        }

        /// <summary>
        /// The node most recently passed to <see cref="VisitNew"/> (set on entry, before rewriting).
        /// With nested new-expressions this may refer to an inner node rather than the one whose
        /// result is in <see cref="LastNewResult"/>.
        /// </summary>
        public NewExpression LastNew { get; private set; }

        /// <summary>The result most recently produced by <see cref="VisitNew"/> (set on exit).</summary>
        public NewExpression LastNewResult { get; private set; }

        /// <summary>Maps each generic argument of <paramref name="method"/> through <see cref="VisitType"/>.</summary>
        IEnumerable<Type> TransformMethodArgs(MethodBase method)
        {
            // Non-generic methods have no generic arguments, so this yields nothing for them.
            return method.GetGenericArguments().Select(VisitType);
        }

        /// <summary>
        /// True when <paramref name="t"/> is, or mentions (e.g. as a generic argument), one of the
        /// replaced types.
        /// </summary>
        /// <remarks>
        /// Uses an ordinal substring test on <see cref="Type.FullName"/>, which embeds generic
        /// arguments. NOTE: this can also match unrelated types whose full name contains a key's
        /// full name (e.g. Ns.IFoo vs Ns.IFooBar).
        /// </remarks>
        bool NeedsTypeChange(Type t)
        {
            var fullName = t.FullName;
            if (fullName == null)
            {
                // FullName is null for generic parameters and for types built from them,
                // so fall back to inspecting the type's structure.
                if (t.HasElementType)
                    return NeedsTypeChange(t.GetElementType());
                return t.IsGenericType && t.GetGenericArguments().Any(NeedsTypeChange);
            }
            return _typeReplacements.Keys.Any(k => k.FullName != null && fullName.Contains(k.FullName));
        }

        /// <summary>
        /// Returns the replacement for <paramref name="t"/>. Mapped types are replaced directly;
        /// arrays and constructed generic types are rebuilt from their rewritten element/argument
        /// types; any other type is returned unchanged.
        /// </summary>
        Type VisitType(Type t)
        {
            Type replacement;
            if (_typeReplacements.TryGetValue(t, out replacement))
                return replacement;

            if (t.IsArray && NeedsTypeChange(t))
            {
                var elementType = t.GetElementType();
                var newElementType = VisitType(elementType);
                // A vector (T[]) must be rebuilt with MakeArrayType(); MakeArrayType(1) would give T[*].
                return t == elementType.MakeArrayType()
                    ? newElementType.MakeArrayType()
                    : newElementType.MakeArrayType(t.GetArrayRank());
            }

            if (t.IsGenericType && t.GetGenericArguments().Any(NeedsTypeChange))
            {
                var types = t.GetGenericArguments().Select(VisitType).ToArray();
                var newType = t.GetGenericTypeDefinition().MakeGenericType(types);
                Debug.Assert(!NeedsTypeChange(newType));
                return newType;
            }

            Debug.Assert(!NeedsTypeChange(t));
            return t;
        }

        /// <summary>
        /// Rebuilds a <see cref="NewExpression"/> with visited arguments, rebinding the constructor
        /// when its declaring generic type (e.g. an anonymous type) mentions a replaced type.
        /// </summary>
        /// <exception cref="InvalidOperationException">No matching constructor exists on the rewritten type.</exception>
        NewExpression TransformNewCall(NewExpression node)
        {
            var constructor = node.Constructor;
            if (constructor == null)
            {
                // Parameterless value-type construction ("new S()") has no constructor and no arguments.
                return Expression.New(VisitType(node.Type));
            }

            // Materialize once: a deferred query would re-run Visit every time it is enumerated.
            var argParams = node.Arguments.Select(a => Visit(a)).ToArray();
            Debug.Assert(!argParams.Any(a => NeedsTypeChange(a.Type)));

            //generic class constructor
            var declaringType = constructor.DeclaringType;
            if (declaringType.IsGenericType && declaringType.GetGenericArguments().Any(NeedsTypeChange))
            {
                var newType = VisitType(declaringType);
                Debug.Assert(!NeedsTypeChange(newType));

                var constructorTypes = constructor.GetParameters().Select(p => VisitType(p.ParameterType)).ToArray();
                Debug.Assert(!constructorTypes.Any(NeedsTypeChange));

                constructor = newType.GetConstructor(constructorTypes);
                if (constructor == null)
                    throw new InvalidOperationException(
                        "Type " + newType + " has no public constructor matching the rewritten parameter types.");
            }

            // A plain "new T(args)" carries no member list.
            if (node.Members == null)
                return Expression.New(constructor, argParams);

            // Safe only because the from type is assumed to be an interface, and a new would be an anonymous type:
            // members are matched by name, keeping the order of the original member list.
            var newMembers = constructor.DeclaringType.GetMembers();
            var members = (from fMember in node.Members
                           join nMember in newMembers on fMember.Name equals nMember.Name
                           select nMember).ToArray();

            var visited = Expression.New(constructor, argParams, members);
            Debug.Assert(visited.Members.Count == node.Members.Count);
            Debug.Assert(!NeedsTypeChange(visited.Type));
            return visited;
        }

        /// <summary>
        /// Rebuilds a method call with a visited instance and visited arguments. For generic
        /// methods the generic arguments are rewritten; non-generic methods are kept as-is and
        /// rely on the rewritten arguments being assignable to the original parameter types.
        /// </summary>
        MethodCallExpression TransformMethodCall(MethodCallExpression node)
        {
            // The instance must be visited too, otherwise it may still reference the old parameters.
            var instance = Visit(node.Object);
            var argParams = node.Arguments.Select(n => Visit(n)).ToArray();
            Debug.Assert(!argParams.Any(a => NeedsTypeChange(a.Type)));

            var methodInfo = node.Method;
            if (methodInfo.IsGenericMethod)
            {
                var argTypes = TransformMethodArgs(methodInfo).ToArray();
                Debug.Assert(!argTypes.Any(NeedsTypeChange));
                methodInfo = methodInfo.GetGenericMethodDefinition().MakeGenericMethod(argTypes);
            }
            Debug.Assert(!NeedsTypeChange(methodInfo.DeclaringType));

            var visited = Expression.Call(instance, methodInfo, argParams);
            Debug.Assert(!NeedsTypeChange(visited.Type));
            return visited;
        }

        /// <summary>
        /// Rewrites calls whose return type or any argument type mentions a replaced type;
        /// other calls are visited normally.
        /// </summary>
        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            var needsTransform = NeedsTypeChange(node.Method.ReturnType)
                || node.Arguments.Any(a => NeedsTypeChange(a.Type));
            if (!needsTransform)
                return base.VisitMethodCall(node);

            var transformed = TransformMethodCall(node);
            Debug.WriteLine("Transformed methodcall");
            return base.VisitMethodCall(transformed);
        }

        /// <summary>
        /// Rebuilds unary nodes (e.g. Convert, TypeAs, Quote) whose result type mentions a
        /// replaced type, using the rewritten operand and result type.
        /// </summary>
        /// <remarks>http://stackoverflow.com/a/9120931/57883</remarks>
        protected override Expression VisitUnary(UnaryExpression node)
        {
            Expression visited;
            if (NeedsTypeChange(node.Type))
            {
                var operand = Visit(node.Operand);
                // Any custom operator method (node.Method) is dropped: it belongs to the old types.
                visited = Expression.MakeUnary(node.NodeType, operand, VisitType(node.Type));
            }
            else
                visited = base.VisitUnary(node);

            Debug.Assert(!NeedsTypeChange(visited.Type));
            return visited;
        }

        /// <summary>
        /// Rewrites constructor calls that mention a replaced type, recording the input and
        /// output in <see cref="LastNew"/>/<see cref="LastNewResult"/> and tracing both in debug builds.
        /// </summary>
        protected override Expression VisitNew(NewExpression node)
        {
            LastNew = node;
            Debug.WriteLine("Transforming NewExpression");
            // Rendered inside the Debug.WriteLine call so it is compiled out of release builds.
            Debug.WriteLine(ExpressionWriter.WriteToString(node));
            Debug.WriteLine(string.Empty);

            Expression visited;
            if (NeedsTypeChange(node.Type) || node.Arguments.Any(a => NeedsTypeChange(a.Type)))
            {
                var transformed = TransformNewCall(node);
                Debug.Assert(!NeedsTypeChange(transformed.Type));
                visited = base.VisitNew(transformed);
            }
            else
                visited = base.VisitNew(node);

            Debug.WriteLine("Transformed to");
            Debug.WriteLine(ExpressionWriter.WriteToString(visited));
            Debug.WriteLine(string.Empty);
            LastNewResult = (NewExpression)visited;
            return visited;
        }

        /// <summary>
        /// Re-types constants that mention a replaced type. Null becomes a null constant of the
        /// replacement type; arrays whose runtime element type is replaced are copied into an
        /// array of the replacement element type; other values are re-typed to their runtime type.
        /// </summary>
        protected override Expression VisitConstant(ConstantExpression node)
        {
            if (!NeedsTypeChange(node.Type))
                return base.VisitConstant(node);

            if (node.Value == null)
                return Expression.Constant(null, VisitType(node.Type));

            var value = node.Value;
            var valueType = value.GetType();
            if (valueType.IsArray && NeedsTypeChange(valueType))
                value = CopyArray((Array)value, VisitType(valueType.GetElementType()));

            // Re-typing to the value's runtime type also covers covariant arrays,
            // e.g. a Foo[] held in an IFoo[] constant.
            var visited = Expression.Constant(value);
            Debug.Assert(!NeedsTypeChange(visited.Type));
            return visited;
        }

        /// <summary>
        /// Copies <paramref name="source"/> into a new array with the same rank and lengths whose
        /// element type is <paramref name="elementType"/>.
        /// </summary>
        /// <exception cref="InvalidCastException">An element is not compatible with <paramref name="elementType"/>.</exception>
        static Array CopyArray(Array source, Type elementType)
        {
            var lengths = new int[source.Rank];
            for (var dimension = 0; dimension < lengths.Length; dimension++)
                lengths[dimension] = source.GetLength(dimension);

            var target = Array.CreateInstance(elementType, lengths);
            // Array.Copy requires equal ranks and copies element-by-element in row-major order.
            Array.Copy(source, target, source.Length);
            return target;
        }

        /// <summary>
        /// Visits <paramref name="node"/> and, in debug builds, asserts that the result no longer
        /// mentions any replaced type.
        /// </summary>
        public override Expression Visit(Expression node)
        {
            if (node == null)
                return null;

            var visited = base.Visit(node);
            Debug.Assert(visited == null || !NeedsTypeChange(visited.Type));
            return visited;
        }

        /// <summary>
        /// Rebuilds lambdas whose return type or parameter types mention a replaced type. The
        /// delegate type is re-inferred from the rewritten body and parameters.
        /// </summary>
        protected override Expression VisitLambda<T>(Expression<T> node)
        {
            Expression visited;
            if (NeedsTypeChange(node.ReturnType) || node.Parameters.Any(p => NeedsTypeChange(p.Type)))
            {
                // Visit the body first: parameters it references get their replacement registered
                // in _paramMappings, and VisitParameter below hands back those same instances.
                var visitedBody = Visit(node.Body);
                Debug.Assert(!NeedsTypeChange(visitedBody.Type));

                var transformedParams = node.Parameters
                    .Select(p => (ParameterExpression)VisitParameter(p))
                    .ToArray();
                Debug.Assert(!transformedParams.Any(p => NeedsTypeChange(p.Type)));

                // Keep the lambda's name and tail-call flag; only the types change.
                var transformed = Expression.Lambda(visitedBody, node.Name, node.TailCall, transformedParams);
                Debug.Assert(!NeedsTypeChange(transformed.Type));
                Debug.Assert(!NeedsTypeChange(transformed.ReturnType));

                // Only re-enter the typed base visitor if the delegate type happens to be unchanged.
                var typed = transformed as Expression<T>;
                visited = typed != null ? base.VisitLambda<T>(typed) : transformed;
            }
            else
                visited = base.VisitLambda<T>(node);

            Debug.Assert(!NeedsTypeChange(visited.Type));
            return visited;
        }

        /// <summary>
        /// Replaces parameters whose type mentions a replaced type. Each original parameter maps
        /// to exactly one replacement instance, so all references to it stay bound.
        /// </summary>
        protected override Expression VisitParameter(ParameterExpression node)
        {
            if (!NeedsTypeChange(node.Type))
                return base.VisitParameter(node);

            ParameterExpression replacement;
            if (!_paramMappings.TryGetValue(node, out replacement))
            {
                var newType = VisitType(node.Type);
                replacement = Expression.Parameter(node.IsByRef ? newType.MakeByRefType() : newType, node.Name);
                _paramMappings.Add(node, replacement);
            }

            var visited = base.VisitParameter(replacement);
            Debug.Assert(!NeedsTypeChange(visited.Type));
            return visited;
        }

        /// <summary>
        /// If we see x.Name on the old type, substitute the same-named public property (or field)
        /// of the replacement type, applied to the visited target expression.
        /// </summary>
        /// <exception cref="InvalidOperationException">The replacement type has no such property or field.</exception>
        protected override Expression VisitMember(MemberExpression node)
        {
            if (!NeedsTypeChange(node.Type) && !NeedsTypeChange(node.Member.DeclaringType))
                return base.VisitMember(node);

            var newType = VisitType(node.Member.DeclaringType);
            // node.Expression is null for static members; Visit(null) returns null.
            var visitedExpression = Visit(node.Expression);
            Debug.Assert(visitedExpression == null || !NeedsTypeChange(visitedExpression.Type));

            var memberName = node.Member.Name;
            MemberInfo targetMember = newType.GetProperty(memberName);
            if (targetMember == null)
                targetMember = newType.GetField(memberName);
            if (targetMember == null)
                throw new InvalidOperationException(
                    "Type " + newType + " has no public property or field named '" + memberName + "'.");

            var rewritten = Expression.MakeMemberAccess(visitedExpression, targetMember);
            Debug.Assert(!NeedsTypeChange(rewritten.Type));

            var visited = base.VisitMember(rewritten);
            Debug.Assert(!NeedsTypeChange(visited.Type));
            return visited;
        }
    }
}
